#!/usr/bin/env python3

import os
import queue
import random
import sys
import time
from threading import Thread

CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, "helpers"))

from pure_http_client import ClickHouseClient

# Table engine (e.g. "ReplicatedMergeTree") is supplied by the wrapper shell test.
table_engine = sys.argv[1]

client = ClickHouseClient()

# test table without partition
# Smoke test: a single async insert into a table with no PARTITION BY;
# the printed count is compared against the test's reference output.
client.query("DROP TABLE IF EXISTS t_async_insert_dedup_no_part SYNC")
create_query = f"""
CREATE TABLE t_async_insert_dedup_no_part (
KeyID     UInt32
) Engine = {table_engine}('/clickhouse/tables/{{shard}}/{{database}}/t_async_insert_dedup', '{{replica}}')
ORDER BY (KeyID)
"""

client.query(create_query)


client.query(
    "insert into t_async_insert_dedup_no_part values (1), (2), (3), (4), (5)",
    settings={
        "async_insert": 1,
        # block until the buffered insert is actually flushed to the table
        "wait_for_async_insert": 1,
        # keep the insert deterministic: no simulated Keeper faults
        "insert_keeper_fault_injection_probability": 0,
    },
)
result = client.query("select count(*) from t_async_insert_dedup_no_part")
print(result, flush=True)
client.query("DROP TABLE IF EXISTS t_async_insert_dedup_no_part SYNC")


# generate data and push to queue
def generate_data(q, total_number, use_token):
    """Produce (insert_statement, dedup_token) pairs on *q*.

    Emits rows 1..total_number in randomly sized chunks, occasionally
    replaying an earlier chunk (a duplicate for dedup to discard) or
    queueing an empty insert.  Finishes by queueing three "end" sentinels,
    one per consumer thread.
    """
    produced = []  # every fresh-data statement, kept so dups can be replayed
    max_chunk = 30
    stamps = ["2022-11-11 10:10:10", "2022-12-12 10:10:10"]
    highest = 0  # largest KeyID emitted so far
    while True:
        # 0 -> replay an old chunk (duplicate), 1 -> empty insert,
        # anything else -> fresh data.
        action = random.randint(0, 4)
        if action == 0:
            # Replay one of the ~50 most recent statements (25% of puts are dups).
            if not produced:
                continue
            pos = max(len(produced) - 1 - random.randint(0, 50), 0)
            q.put(produced[pos])
        # NOTE(review): deliberately `if`, not `elif` — when action == 0 the
        # replay above also falls through into the fresh-data branch below,
        # mirroring the original control flow.
        if action == 1:
            q.put(("insert into t_async_insert_dedup format JSONEachRow", ""))
        else:
            # Fresh data: the next 1..max_chunk numbers, capped at total_number.
            first = highest + 1
            last = min(first + random.randint(1, max_chunk), total_number)
            token = ""
            values = []
            for num in range(first, last + 1):
                stamp = stamps[random.randint(0, 1)]
                values.append("('{}', {})".format(stamp, num))
                if use_token:
                    # token of the last row stands in for the whole chunk
                    token = str(num)
            stmt = "insert into t_async_insert_dedup values " + ",".join(values)
            q.put((stmt, token))
            produced.append((stmt, token))
            highest = last
            if last >= total_number:
                break
    # One sentinel per consumer thread so they all shut down.
    for _ in range(3):
        q.put("end")


def fetch_and_insert_data(q, client):
    """Worker loop: pull statements off *q* and run them as async inserts.

    Each (statement, token) pair is executed with async-insert deduplication
    enabled, then the worker naps 50-500 ms so flushes happen at varied
    times.  Returns when the "end" sentinel is received.
    """
    while True:
        task = q.get()
        if task == "end":
            return
        stmt, token = task
        client.query(
            stmt,
            settings={
                "async_insert": 1,
                "async_insert_deduplicate": 1,
                "wait_for_async_insert": 1,
                "async_insert_busy_timeout_ms": 1500,
                # deterministic inserts: no simulated Keeper faults
                "insert_keeper_fault_injection_probability": 0,
                "insert_deduplication_token": token,
            },
            with_retries=False,
        )
        q.task_done()
        # random jitter between inserts to spread them across flush windows
        time.sleep(random.randint(50, 500) / 1000.0)


# main process
client.query("DROP TABLE IF EXISTS t_async_insert_dedup SYNC")
# Partitioned replicated table; use_async_block_ids_cache=1 exercises the
# async-insert dedup cache whose metrics are verified at the end.
create_query = f"""
CREATE TABLE t_async_insert_dedup (
EventDate DateTime,
KeyID     UInt32
) Engine = {table_engine}('/clickhouse/tables/{{shard}}/{{database}}/t_async_insert_dedup', '{{replica}}')
PARTITION BY toYYYYMM(EventDate)
ORDER BY (KeyID, EventDate) SETTINGS use_async_block_ids_cache = 1
"""

client.query(create_query)

# Bounded queue so the generator cannot run far ahead of the inserters.
q = queue.Queue(100)
total_number = 3000

# Optional second argv value "token" enables explicit dedup tokens.
use_token = False
if len(sys.argv) >= 3 and sys.argv[2] == "token":
    use_token = True

gen = Thread(target=generate_data, args=[q, total_number, use_token])
gen.start()

# Three concurrent inserters; generate_data queues one "end" sentinel each.
insert_threads = []
for i in range(3):
    insert = Thread(target=fetch_and_insert_data, args=[q, client])
    insert.start()
    insert_threads.append(insert)

gen.join()
for insert_thread in insert_threads:
    insert_thread.join()

retry = 0

# Poll until deduplication settles: despite the duplicate inserts, exactly
# KeyIDs 1..total_number (each once) must remain in the table.
while True:
    time.sleep(5)
    result = client.query("select KeyID from t_async_insert_dedup order by KeyID")
    result = result.split()
    err = False
    errMsg = ""
    if len(result) != total_number:
        err = True
        errMsg = f"the size of result is {len(result)}. we expect {total_number}."
    else:
        for i in range(total_number):
            expect = str(i + 1)
            real = result[i]
            if expect != real:
                err = True
                errMsg = f"error, real value {real} is not equal to expect value {expect} for {i}-th elements"
                break
    # retry several times to get stable results.
    if err and retry >= 5:
        raise Exception(errMsg)
    elif err:
        retry += 1
        continue
    break

# Sanity-check that the async block-ids cache was actually populated and hit
# (i.e. deduplication went through the cache, not just Keeper lookups).
result = client.query(
    "SELECT value FROM system.metrics where metric = 'AsyncInsertCacheSize'"
)
result = int(result.split()[0])
if result <= 0:
    raise Exception(f"AsyncInsertCacheSize should > 0, but got {result}")

result = client.query(
    "SELECT value FROM system.events where event = 'AsyncInsertCacheHits'"
)
result = int(result.split()[0])
if result <= 0:
    raise Exception(f"AsyncInsertCacheHits should > 0, but got {result}")

client.query("DROP TABLE IF EXISTS t_async_insert_dedup SYNC")

